from emodel_generalisation.mcmc import load_chains
import seaborn as sns
import shap
from tqdm import tqdm
import pandas as pd
import numpy as np
from emodel_generalisation.mcmc import plot_corner
import matplotlib.pyplot as plt
from xgboost import XGBRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import RepeatedKFold


def _get_shap_feature_importance(shap_values):
    """Average SHAP values over folds and derive a per-feature importance.

    Args:
        shap_values: sequence of per-fold SHAP arrays, all with the same shape.

    Returns:
        tuple: ``(mean_shap_values, shap_feature_importance)``. ``mean_shap_values`` is the
        fold-averaged array; when that average has more than two dimensions it is returned as a
        list split along its first axis. ``shap_feature_importance`` is the mean over axis 0 of
        the absolute value of the (for >2-D inputs, first-axis-averaged) fold mean.
    """
    fold_average = np.mean(shap_values, axis=0)
    is_multi_output = np.ndim(fold_average) > 2

    # For >2-D averages, collapse the leading axis before taking magnitudes.
    pooled = np.mean(fold_average, axis=0) if is_multi_output else fold_average
    importance = np.mean(np.abs(pooled), axis=0)

    if is_multi_output:
        fold_average = list(fold_average)
    return fold_average, importance


def train(
    X,
    y,
    n_splits=2,
    n_repeats=1,
    ax=None,
    target="Step_ReboundBurst_burst.soma.v.all_burst_number",
):
    """Cross-validate an XGBoost regressor and collect per-fold SHAP values.

    For each fold of a ``RepeatedKFold`` split (``random_state=42``), the model is:

    - fitted on the training rows;
    - used to predict the validation rows, drawn on ``ax`` as a boxplot of predicted values
      grouped by true value;
    - scored with the mean absolute error on the validation rows;
    - explained with SHAP values computed for *all* rows of ``X`` using that fold's model.

    Args:
        X (pandas.DataFrame): input features, one row per sample.
        y (pandas.DataFrame): target values, row-aligned with ``X``.
        n_splits (int): number of folds per repetition.
        n_repeats (int): number of repetitions of the K-fold split.
        ax (matplotlib.axes.Axes): axes for the true-vs-predicted boxplot; ``None`` is passed
            through to seaborn.
        target (str): column of ``y`` used as the "true" value in the boxplot.

    Returns:
        tuple: ``(acc_scores, shap_values, shap_interaction_values)``.
            ``acc_scores`` holds one mean absolute error per fold.
            ``shap_values`` holds one SHAP array per fold.
            ``shap_interaction_values`` is always an empty list, because interactions are not
            computed here.
    """
    model = XGBRegressor()
    folds = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=42)
    acc_scores = []
    shap_values = []
    for train_index, val_index in tqdm(folds.split(X, y=y), total=folds.get_n_splits()):
        model.fit(X.iloc[train_index], y.iloc[train_index])

        y_val = y.iloc[val_index]
        # Predict once; the same predictions feed both the plot and the score.
        predicted = model.predict(X.iloc[val_index])

        df = pd.DataFrame({"true": y_val[target].to_list(), "predicted": predicted})
        print(df)
        sns.boxplot(data=df, x="true", y="predicted", ax=ax, native_scale=True, showfliers=False)

        acc_scores.append(mean_absolute_error(y_val, predicted))

        explainer = shap.TreeExplainer(model)
        shap_values.append(explainer.shap_values(X))

    # Return only after the loop so that every one of the n_splits * n_repeats folds is used.
    return acc_scores, shap_values, []


def plot(X, shap_values, shap_interaction_values, name):
    """Save fold-averaged SHAP summary plots for ``X``.

    Writes two files:

    - ``bar_shap_{name}.pdf``: a ``plot_type="bar"`` summary plot.
    - ``dot_shap_{name}.pdf``: a ``plot_type="dot"`` summary plot, with the color bar labelled
      "parameter value".

    Both show at most 5 features at a plot size of (8, 3). The fold-averaged feature importance
    and the column names of ``X`` are printed first.

    Args:
        X (pandas.DataFrame): feature matrix the SHAP values were computed on.
        shap_values (list): per-fold SHAP value arrays.
        shap_interaction_values: accepted for call compatibility but currently ignored; no
            interaction plots are produced.
        name (str): suffix used in the output file names.
    """
    mean_shap_values, shap_feature_importance = _get_shap_feature_importance(shap_values)
    print(shap_feature_importance, X.columns)

    # The two summary plots differ only in plot type and one optional keyword.
    for plot_type, extra_kwargs in (
        ("bar", {}),
        ("dot", {"color_bar_label": "parameter value"}),
    ):
        plt.figure()
        shap.summary_plot(
            mean_shap_values,
            X,
            plot_type=plot_type,
            max_display=5,
            show=False,
            plot_size=(8, 3),
            **extra_kwargs,
        )
        plt.tight_layout()
        plt.savefig(f"{plot_type}_shap_{name}.pdf")
        plt.close()


if __name__ == "__main__":
    # Feature used as the regression target and as the corner-plot color throughout the script.
    BURST_NUMBER = "Step_ReboundBurst_burst.soma.v.all_burst_number"

    mcmc_df = load_chains(
        "../../../mcmc_run/run_df.csv",
        base_path="../../../mcmc_run",
    )
    mcmc_df = mcmc_df[mcmc_df.cost < 2.0]
    mask_runaway = (mcmc_df["features"]["Step_ReboundBurst_burst.soma.v.burst_runaway"] < 0.05) & (
        mcmc_df["features"]["Step_ReboundBurst_burst.soma.v.time_to_last_spike"] > 20000
    )

    # Standalone histogram of burst numbers; closed once saved so it cannot become the
    # implicit "current figure" for later pyplot calls.
    hist_fig, hist_ax = plt.subplots()
    hist_ax.hist(mcmc_df["features"][BURST_NUMBER], bins=100, range=[0, 100])
    hist_fig.savefig("hist_burst_number.pdf")
    plt.close(hist_fig)

    fig, ax = plt.subplots(figsize=(6, 4))

    # Labels match the output-file suffixes; rows selected by ``mask_runaway`` form the
    # "runaway" group. Both groups draw their validation boxplots on the same ``ax``.
    subsets = {"non_runaway": ~mask_runaway, "runaway": mask_runaway}
    for label, mask in subsets.items():
        _mcmc_df = mcmc_df[mask].reset_index(drop=True)
        X = _mcmc_df["normalized_parameters"]
        y = pd.DataFrame(_mcmc_df["features"][BURST_NUMBER])
        acc_scores, shap_values, shap_interaction_values = train(X, y, ax=ax)
        print(acc_scores)
        plot(X, shap_values, shap_interaction_values, f"all_burst_number_{label}")

    # Identity line: points on it are perfectly predicted.
    ax.plot([0, 90], [0, 90], color="k")
    ax.axis([0, 90, 0, 90])
    ax.set_xlabel("true burst number")
    ax.set_ylabel("predicted burst number")

    # Overlay the burst-number distribution of all models on a secondary y-axis.
    ax1 = ax.twinx()
    ax1.set_ylabel("# models")
    ax1.hist(
        mcmc_df["features"][BURST_NUMBER],
        bins=100,
        range=[0, 100],
        histtype="step",
    )
    fig.tight_layout()
    fig.savefig("accuracy.pdf")

    # For each group, keep only these "normalized_parameters" columns in the corner plot.
    corner_parameters = {
        "non_runaway": (~mask_runaway, ["g_pas.all", "gcabar_it2.basal"]),
        "runaway": (mask_runaway, ["gbar_ican.basal", "shift_it2.somadend"]),
    }
    for label, (mask, kept) in corner_parameters.items():
        # Exact match on the top-level group: ``in`` against a string would be a substring test
        # and would also match other groups (e.g. "parameters").
        dropped = [
            c for c in mcmc_df.columns if c[0] == "normalized_parameters" and c[1] not in kept
        ]
        plot_corner(
            mcmc_df[mask].drop(columns=dropped),
            feature=("features", BURST_NUMBER),
            filename=f"burst_number_{label}.pdf",
        )
